package org.bycloud.common.base;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.bycloud.common.Pager;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.jdbc.core.BeanPropertyRowMapper;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.stereotype.Component;

import lombok.Getter;

/**
 * Generic JDBC DAO adding dialect-aware paging on top of {@link BaseJdbcDao}.
 *
 * <p>The database is identified once, in {@link #afterPropertiesSet()}, from the JDBC product
 * name. {@link #queryPagingSql} then wraps the caller's query in Oracle ({@code ROWNUM}),
 * SQL Server / DB2 ({@code ROW_NUMBER()}) or MySQL-style ({@code LIMIT ?, ?}) paging SQL; the
 * MySQL form is also used for any database that was not recognised.
 */
@Component
@Getter
public class SimpleJdbcDao<M> extends BaseJdbcDao<M> implements InitializingBean {

    /**
     * Oracle paging wrapper. {@code {placeholder}} receives the caller's query; the bind
     * parameters are the inclusive upper ROWNUM bound followed by the exclusive lower bound.
     */
    private static final String ORACLE_LIMIT_TEMPLATE =
            "select * from (select row_.*, rownum rownum_ from ({placeholder}) row_ where rownum <= ? ) where rownum_ > ?";

    /**
     * SQL Server / DB2 paging wrapper. {@code {main}} receives the caller's query without its
     * leading SELECT keyword, and {@code order by (select 0)} is swapped for the query's own
     * ORDER BY when it has one. Bind parameters: exclusive lower bound, inclusive upper bound.
     */
    private static final String SQLSERVER_LIMIT_TEMPLATE =
            "SELECT * FROM (SELECT ROW_NUMBER() OVER (order by (select 0)) AS rownum, {main} ) rs WHERE rs.rownum > ? and rs.rownum <= ?";

    /**
     * Tokens scanned by {@link #addSubalias(String)}: group 1 is the token following
     * {@code * from}; group 2 is an opening parenthesis; group 3 is a closing parenthesis and
     * group 4 the non-blank token right after it.
     */
    private static final Pattern FROM_CLAUSE_PTN =
            Pattern.compile("(?i)\\*\\s+from\\s+([\\S]+)|(?:(\\()|(\\))\\s*([\\S]+))");

    /** Case-insensitive, whole-word {@code order by} (any whitespace between the two words). */
    private static final Pattern ORDER_BY_PTN = Pattern.compile("(?i)\\border\\s+by\\b");

    /**
     * Case-insensitive, whole-word {@code from}; identifiers such as {@code from_date} do not
     * match.
     */
    private static final Pattern FROM_KEYWORD_PTN = Pattern.compile("(?i)\\bfrom\\b");

    /**
     * Prefix (control character U+0005 followed by '_') of the placeholder that
     * {@link #addSubalias(String)} writes before a sub-select alias is known.
     */
    private static final String SUBALIAS_MARKER = "\5_";

    // Dialect flags, set once by afterPropertiesSet() and exposed through the Lombok-generated
    // isXxx() getters. DB2 also sets isSqlServer because it shares the ROW_NUMBER() path.
    boolean isMySql;
    boolean isOracle;
    boolean isSqlServer;
    boolean isDB2;
    protected final Object[] NULL_PARA_ARRAY = new Object[0];
    protected final Set<String> emptySet = new HashSet<String>();

    /**
     * Wraps {@code baseSQL} in the Oracle ROWNUM paging template.
     *
     * @param baseSQL the query to page; its own ORDER BY stays inside the inner query
     * @return SQL expecting two extra bind parameters: {@code start + limit}, then {@code start}
     */
    public String getOracleLimitSQL(String baseSQL) {
        return ORACLE_LIMIT_TEMPLATE.replace("{placeholder}", baseSQL);
    }

    // First SELECT keyword, case-insensitive. Left as an instance field because the class-level
    // Lombok @Getter also generates getSqlServerSelectPtn() for it.
    private final Pattern sqlServerSelectPtn = Pattern.compile("(?i)select");

    /**
     * Wraps {@code baseSQL} in the ROW_NUMBER() paging template (ROWNUMBER() for DB2).
     *
     * <p>The first SELECT keyword is removed and the remainder becomes the select list of the
     * numbered sub-query. The first ORDER BY (matched case-insensitively) is moved into the
     * OVER(...) clause; without one, rows are numbered by a constant expression.
     *
     * @param baseSQL the query to page
     * @return SQL expecting two extra bind parameters: {@code start}, then {@code start + limit}
     */
    public String getSqlServerLimitSQL(String baseSQL) {
        // The template supplies its own SELECT in front of {main}.
        baseSQL = sqlServerSelectPtn.matcher(baseSQL).replaceFirst("");
        String tpl = SQLSERVER_LIMIT_TEMPLATE;
        if (isDB2()) {
            tpl = tpl.replace("ROW_NUMBER", "ROWNUMBER");
            // Qualify "*" — presumably DB2 rejects a bare "*" beside the row-number column.
            baseSQL = addSubalias(baseSQL);
        }
        Matcher orderBy = ORDER_BY_PTN.matcher(baseSQL);
        int orderByIdx = orderBy.find() ? orderBy.start() : -1;
        String mainSql = baseSQL;
        if (orderByIdx > 0) {
            // Number the rows in the caller's order by moving its ORDER BY into OVER(...).
            mainSql = baseSQL.substring(0, orderByIdx);
            tpl = tpl.replace("order by (select 0)", baseSQL.substring(orderByIdx));
        } else if (isDB2()) {
            tpl = tpl.replace("select 0", "select 1 from sysibm.sysdummy1");
        }
        return tpl.replace("{main}", mainSql);
    }

    /**
     * Qualifies {@code * from X} occurrences for the DB2 paging template.
     *
     * <p>{@code * from table} becomes {@code table.* from table}. For
     * {@code * from (sub-select) alias} the alias is only known after the matching ")", so a
     * placeholder ({@link #SUBALIAS_MARKER} plus an index) is written first and replaced once
     * the alias has been read. Known limitations: the index is always 0, so only one top-level
     * sub-select is resolved; a table alias ({@code * from users u}) is not used; and the
     * placeholder stays in the SQL if the sub-select has no alias.
     */
    private static String addSubalias(String basesql) {
        Matcher matcher = FROM_CLAUSE_PTN.matcher(basesql);
        int subalias = 0;
        StringBuffer b = new StringBuffer();
        int subSelectLevel = 0;
        List<String> subaliasList = new ArrayList<String>();
        while (matcher.find()) {
            if (matcher.group(1) != null) {
                if (matcher.group(1).contains("(")) {
                    // "* from (sub-select": the "(" was consumed by this match, so depth is 1.
                    if (subSelectLevel == 0) subSelectLevel = 1;
                    appendLiteral(matcher, b, SUBALIAS_MARKER + subalias + "." + matcher.group());
                } else {
                    appendLiteral(matcher, b, matcher.group(1) + "." + matcher.group());
                }
            } else if (subSelectLevel > 0) {
                // Inside a sub-select: copy the token unchanged, tracking parenthesis depth.
                appendLiteral(matcher, b, matcher.group());
                if (matcher.group(2) != null) {
                    subSelectLevel++;
                } else if (matcher.group(3) != null) {
                    subSelectLevel--;
                    if (subSelectLevel == 0) {
                        // The token after the sub-select's closing ")" is its alias.
                        subaliasList.add(matcher.group(4));
                    }
                }
            }
            // Parenthesis tokens outside a sub-select are not replaced here; the next
            // appendReplacement / appendTail copies them through unchanged.
        }
        matcher.appendTail(b);
        String sql = b.toString();
        for (int i = 0; i < subaliasList.size(); i++) {
            sql = sql.replace(SUBALIAS_MARKER + i, subaliasList.get(i));
        }
        return sql;
    }

    /**
     * Appends {@code text} verbatim via {@link Matcher#appendReplacement}, so that "$" or "\"
     * in SQL are not interpreted as group references.
     */
    private static void appendLiteral(Matcher matcher, StringBuffer sb, String text) {
        matcher.appendReplacement(sb, Matcher.quoteReplacement(text));
    }

    /**
     * Runs {@code sql} for one page, using the paging syntax of the detected database.
     *
     * @param sql the base query, including any ORDER BY
     * @param start zero-based offset of the first row to return
     * @param limit maximum number of rows to return
     * @param args bind parameters of {@code sql}; the paging parameters are appended after them
     * @param rowmapper maps each result row
     * @return the mapped rows, or an empty list (without querying) when {@code start} or
     *     {@code limit} is -1
     */
    public List<M> queryPagingSql(String sql, int start, int limit, Object[] args,
            RowMapper<M> rowmapper) {
        // NOTE(review): -1 may have been intended as "no paging" (return all rows); as written
        // it returns nothing. Confirm with callers before changing.
        if (start == -1 || limit == -1) {
            return Collections.emptyList();
        }
        if (isOracle()) {
            sql = getOracleLimitSQL(sql);
            args = ArrayUtils.addAll(args, new Object[] {start + limit, start});
        } else if (isSqlServer() || isDB2()) {
            sql = getSqlServerLimitSQL(sql);
            args = ArrayUtils.addAll(args, new Object[] {start, start + limit});
        } else {
            // MySQL syntax, also the fallback for unrecognised databases.
            sql += " limit ?, ?";
            args = ArrayUtils.addAll(args, new Object[] {start, limit});
        }
        return jdbcTemplate.query(sql, args, rowmapper);
    }

    /**
     * Detects the database dialect from the JDBC product name and sets the dialect flags.
     *
     * <p>Matching is case-insensitive because drivers typically report mixed-case names
     * (e.g. "MySQL", "Oracle", "Microsoft SQL Server", "DB2/LINUXX8664"). The connection used
     * to read the metadata is always closed.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        Connection conn = jdbcTemplate.getDataSource().getConnection();
        String dialect;
        try {
            DatabaseMetaData md = conn.getMetaData();
            dialect = md.getDatabaseProductName();
        } finally {
            // Release the connection; otherwise a pooled connection would be leaked.
            conn.close();
        }
        if (dialect == null) {
            return;
        }
        String product = dialect.toLowerCase(Locale.ROOT);
        if (product.contains("mysql") || product.contains("mariadb")) {
            isMySql = true;
        } else if (product.contains("oracle")) {
            isOracle = true;
        } else if (product.contains("sql server") || product.contains("sqlserver")) {
            isSqlServer = true;
        } else if (product.contains("db2")) {
            isDB2 = true;
            isSqlServer = true;
        }
    }

    /**
     * Returns one page of {@code sql}'s results together with the total row count.
     *
     * <p>The count query is {@code select count(*)} followed by {@code sql}'s text from its
     * first FROM keyword up to its first ORDER BY. A sub-select before the main FROM (e.g. in
     * the select list) is not handled.
     *
     * @param start zero-based offset of the first row
     * @param limit maximum number of rows in the page
     * @param sql the query; must contain a FROM keyword
     * @param args bind parameters of {@code sql}
     * @return the page and the total count
     * @throws IllegalArgumentException if {@code sql} has no FROM keyword
     */
    public Pager<M> findPager(int start, int limit, String sql, Object... args) {
        int count = findCount("select count(*) " + countFromClause(sql), args);
        List<M> data =
                queryPagingSql(sql, start, limit, args, new BeanPropertyRowMapper<M>(entityClass));
        return Pager.getResult(count, data);
    }

    /**
     * Extracts the part of {@code sql} from its first whole-word FROM (any case) up to, but
     * excluding, the first ORDER BY after it.
     */
    private static String countFromClause(String sql) {
        Matcher from = FROM_KEYWORD_PTN.matcher(sql);
        if (!from.find()) {
            throw new IllegalArgumentException("No FROM clause found in SQL: " + sql);
        }
        String fromClause = sql.substring(from.start());
        Matcher orderBy = ORDER_BY_PTN.matcher(fromClause);
        return orderBy.find() ? fromClause.substring(0, orderBy.start()) : fromClause;
    }

}
